※ This article is a work in progress.

Introduction

Score Matching is a method for estimating probabilistic models, proposed by Hyvärinen in 2005.

Its advantage is that the probabilistic model can be estimated without sampling from the model distribution.

Problem setting

Assume the observed data ${\boldsymbol x} \in \mathbb{R}^n$ follows a probability density $\mathbb{P}_{data}({\boldsymbol x})$.

We then adjust the parameters $\theta$ of a model distribution $\mathbb{P}_{\theta}({\boldsymbol x})$ so that it approaches $\mathbb{P}_{data}({\boldsymbol x})$.

The difficulty is that estimating $\mathbb{P}_{\theta}({\boldsymbol x})$ directly requires the partition function $Z(\theta)$, a normalizing constant that is generally intractable to compute.

To make this explicit, we write the model in energy-based form by introducing an energy function $E_{\theta}(\boldsymbol{x})$:

$$ \mathbb{P}_{\theta}({\boldsymbol x})=\frac{1}{Z(\theta)}\exp(-E_{\theta}(\boldsymbol{x})), \qquad Z(\theta) = \int \exp(-E_{\theta}(\boldsymbol{x}))\, d{\boldsymbol x} $$
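For intuition, consider the standard Gaussian, one of the few models whose partition function has a closed form (a textbook example, not specific to this article):

$$ E_{\theta}(\boldsymbol{x}) = \frac{1}{2}\parallel {\boldsymbol x} \parallel^2, \qquad Z(\theta) = \int \exp\left(-\frac{1}{2}\parallel {\boldsymbol x} \parallel^2\right) d{\boldsymbol x} = (2\pi)^{n/2} $$

For a flexible energy function, such as the neural network used below, this integral has no closed form, which is exactly why direct estimation of $\mathbb{P}_{\theta}$ is hard.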

Score Matching brings the model distribution close to the data distribution using only the gradient of the energy function, so $Z(\theta)$ never has to be computed.

Score function

The score function $\psi_\theta ({\boldsymbol x})$ of the model density is defined as the gradient of the log density:

$$ \begin{aligned} \psi_\theta ({\boldsymbol x}) = \nabla_{\boldsymbol x} \log \mathbb{P}_{\theta}({\boldsymbol x}) = - \nabla_{\boldsymbol x} E_{\theta}({\boldsymbol x}) \end{aligned} $$

The second equality holds because $\log Z(\theta)$ does not depend on ${\boldsymbol x}$ and vanishes under the gradient. This is the key point: the score function is free of $Z(\theta)$.

Similarly, define the score function of the data distribution as $\psi_{data}({\boldsymbol x})=\nabla_{\boldsymbol x} \log \mathbb{P}_{data}({\boldsymbol x})$.
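For the Gaussian illustration above, the score works out to

$$ \psi({\boldsymbol x}) = \nabla_{\boldsymbol x} \left( -\frac{1}{2}\parallel {\boldsymbol x} \parallel^2 - \log Z \right) = -{\boldsymbol x} $$

and the $\log Z$ term has dropped out, exactly as claimed.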

Loss function

The discrepancy between the model and the data is measured by the expected squared distance between the two score functions:

$$ \begin{aligned} J(\theta) = \frac{1}{2}\int \mathbb{P}_{data}({\boldsymbol x}) \parallel \psi_\theta({\boldsymbol x})-\psi_{data}({\boldsymbol x}) \parallel^2 d{\boldsymbol x} \end{aligned} $$
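As written, $J(\theta)$ cannot be evaluated because $\psi_{data}$ is unknown. Hyvärinen (2005) showed via integration by parts that, under mild regularity conditions, it can be rewritten using only the model score:

$$ J(\theta) = \int \mathbb{P}_{data}({\boldsymbol x}) \sum_{i=1}^{n} \left[ \partial_i \psi_{\theta,i}({\boldsymbol x}) + \frac{1}{2} \psi_{\theta,i}({\boldsymbol x})^2 \right] d{\boldsymbol x} + \mathrm{const.} $$

where $\partial_i \psi_{\theta,i}$ is the $i$-th diagonal entry of the Jacobian of the score. Replacing the integral with an empirical mean over the training data gives the quantity minimized in the implementation below, up to an overall factor of 2: $\parallel \psi_\theta \parallel^2 + 2\,\mathrm{tr}\, \nabla_{\boldsymbol x} \psi_\theta$.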

Implementation

Modules used

import numpy as np
import torch
from torch import nn

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D
import tqdm.notebook as tq

from sklearn.datasets import make_circles
from sklearn.preprocessing import MinMaxScaler

from matplotlib import rc
rc('animation', html='jshtml')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)

Data preparation

data, _ = make_circles(n_samples=5000, factor=0.5, noise=0.015, random_state=0)

# Preprocessing: scale both coordinates into [0, 1]
scaler = MinMaxScaler()
train_data = scaler.fit_transform(data)

# Plot
plt.scatter(train_data[:,0], train_data[:,1])

# Data loader for training
train_tensor = torch.from_numpy(train_data.astype("float32"))
train_dataset = torch.utils.data.TensorDataset(train_tensor)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Test data: an n x n evaluation grid over [0, 1]^2
n = 50
arr = np.linspace(0.0, 1.0, n)
m1, m2 = np.meshgrid(arr, arr)
test_data = np.hstack([m1.reshape(n**2, 1), m2.reshape(n**2, 1)])
test_tensor = torch.from_numpy(test_data.astype("float32")).to(device)
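A quick shape check (illustrative only; it uses the variables defined above):

batch, = next(iter(train_loader))
print(batch.shape)        # torch.Size([64, 2]) -- one mini-batch
print(test_tensor.shape)  # torch.Size([2500, 2]) -- the n*n evaluation grid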

Model architecture

class NN(nn.Module):
  def __init__(self):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(2, 64),
        nn.SiLU(),
        nn.Linear(64, 64),
        nn.SiLU(),
        nn.Linear(64, 64)
    )

  def forward(self, x):
    # Sum the 64 outputs into a single scalar energy per sample
    output = self.net(x).sum(axis=1)
    return output
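Because the final Linear(64, 64) layer is followed by .sum(axis=1), the network returns one scalar energy per input point. A minimal check (illustrative only):

check_model = NN()
x_check = torch.randn(8, 2)        # a batch of 8 two-dimensional points
print(check_model(x_check).shape)  # torch.Size([8]) -- one energy per sample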

Loss function

def loss_func(x, energy):

    # First derivative: the score psi = -dE/dx
    first_diff, = torch.autograd.grad(-energy, x,
                                      grad_outputs = torch.ones_like(energy),
                                      create_graph = True)

    # Second derivative: only the diagonal of the score's Jacobian is needed
    # for the trace term. A single vector-Jacobian product would also pick up
    # off-diagonal Hessian entries, so each coordinate is differentiated
    # separately.
    trace = 0.0
    for i in range(x.shape[1]):
        second_diff = torch.autograd.grad(first_diff[:, i].sum(), x,
                                          create_graph = True)[0]
        trace = trace + second_diff[:, i]

    # Empirical score-matching objective: ||psi||^2 + 2 * tr(d psi / dx)
    return ((first_diff**2).sum(axis=1) + 2 * trace).mean()
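A hand-computable sanity check of the trace computation (the Quadratic class below is a hypothetical helper, not part of the training code). For $E({\boldsymbol x}) = \frac{1}{2}\parallel{\boldsymbol x}\parallel^2 + \frac{1}{2} x_1 x_2$ the score is $\psi = -(x_1 + \frac{1}{2}x_2,\; x_2 + \frac{1}{2}x_1)$ and the Jacobian trace is $-2$, so the loss should equal $\mathbb{E}[\parallel\psi\parallel^2] - 4$:

class Quadratic(nn.Module):
    def forward(self, x):
        return 0.5 * (x ** 2).sum(axis=1) + 0.5 * x[:, 0] * x[:, 1]

x_chk = torch.randn(16, 2, requires_grad=True)
loss_auto = loss_func(x_chk, Quadratic()(x_chk))

psi = -torch.stack([x_chk[:, 0] + 0.5 * x_chk[:, 1],
                    x_chk[:, 1] + 0.5 * x_chk[:, 0]], dim=1)
loss_manual = (psi ** 2).sum(axis=1).mean() - 4.0
print(loss_auto.item(), loss_manual.item())  # the two values should agree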
model = NN().to(device)

# Training logs
loss_log = []
energy_log = []
data_energy_log = []

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

Running the estimation

epoch = 3000
for e in range(1, epoch+1):

    size = len(train_loader.dataset)
    loss_sum = 0.0

    for sample, in train_loader:

        # Set up the input batch; gradients w.r.t. the inputs are required
        sample = sample.to(device)
        sample.requires_grad = True
        energy = model(sample)
        loss = loss_func(sample, energy)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight each batch mean by its size for a per-sample epoch average
        loss_sum += loss.item() * sample.shape[0]

    loss_log.append(loss_sum / size)

    if e % 10 == 0:
        recent = loss_log[-100:]
        print("epoch : {}/{}, loss : {}".format(e, epoch, sum(recent) / len(recent)))

        # Snapshot the energy on the evaluation grid and on the last batch
        model.eval()
        with torch.no_grad():
            energy = model(test_tensor).to('cpu').numpy()
            data_energy = model(sample).to('cpu').numpy()
        energy_log.append(energy)
        data_energy_log.append(data_energy)
        model.train()
epoch : 10/3000, loss : -0.08338470011949539
epoch : 20/3000, loss : -0.2878255546092987
epoch : 30/3000, loss : -2.2863991260528564
epoch : 40/3000, loss : -8.685164451599121
epoch : 50/3000, loss : -17.223005294799805
(... intermediate output omitted; the running loss decreases steadily ...)
epoch : 2990/3000, loss : -422.8204345703125
epoch : 3000/3000, loss : -422.61260986328125

Visualization

plt.rcParams["animation.embed_limit"] = 30.0
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.view_init(elev=60, azim=45)

def update(frame):
    ax.cla()
    f_energy = energy_log[frame]
    # Clip extreme values so the surface stays readable
    c = np.percentile(f_energy, q=[20, 80])
    f_energy = np.clip(f_energy, c[0], c[1])
    ax.plot_surface(m1, m2, f_energy.reshape(n, n), cmap="magma")
    ax.set_zlim(c[0], c[1])
    #ax.scatter3D(train_data[:,0], train_data[:,1])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.set_title("epoch : {}".format((frame+1)*10))  # snapshots are taken every 10 epochs

anim = FuncAnimation(fig, update, frames=len(energy_log), interval=50)
anim
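Beyond the energy surface, the learned score field itself can be visualized: take the gradient of the negative energy over the evaluation grid and draw it with plt.quiver. A minimal sketch, reusing model, test_tensor, m1, m2 and n from above (the stride step = 4 is an arbitrary thinning for readability):

# Score psi = -dE/dx on the evaluation grid
grid = test_tensor.clone().requires_grad_(True)
energy = model(grid)
score, = torch.autograd.grad(-energy.sum(), grid)
score = score.to('cpu').numpy()

# The arrows should point toward the two circles of the training data
u = score[:, 0].reshape(n, n)
v = score[:, 1].reshape(n, n)
step = 4
plt.figure()
plt.quiver(m1[::step, ::step], m2[::step, ::step], u[::step, ::step], v[::step, ::step])
plt.scatter(train_data[:, 0], train_data[:, 1], s=1, alpha=0.3)
plt.show()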
</input>